/***************************************************************************************************
 * Copyright (c) 2017-2020, NVIDIA CORPORATION.  All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 *modification, are permitted provided that the following conditions are met:
 *     * Redistributions of source code must retain the above copyright notice,
 *this list of conditions and the following disclaimer.
 *     * Redistributions in binary form must reproduce the above copyright
 *notice, this list of conditions and the following disclaimer in the
 *documentation and/or other materials provided with the distribution.
 *     * Neither the name of the NVIDIA CORPORATION nor the names of its
 *contributors may be used to endorse or promote products derived from this
 *software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
 *AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 *IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 *DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY DIRECT,
 *INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
 * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
 *DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
 *OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 *NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 *EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 *
 **************************************************************************************************/
/* \file
  \brief Defines operations for all CONV operation kinds in CUTLASS Library
*/

#pragma once

#include <cstdint>
#include <cstring>
#include <iostream>
#include <sstream>
#include <string>

#include "cutlass/cutlass.h"

#include "cutlass/library/library.h"
#include "cutlass/library/manifest.h"
#include "cutlass/library/util.h"
#include "library_internal.h"

#include "cutlass/util/reference/host/convolution.h"
#include "cutlass/util/reference/device/convolution.h"

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace library {

///////////////////////////////////////////////////////////////////////////////////////////////////

namespace detail {

/// Primary template for dispatching to a host- or device-side convolution
/// reference implementation. Only the ConvDim == 2 and ConvDim == 3 partial
/// specializations below are defined; other dimensions are unsupported.
template <Provider kProvider, conv::Operator ConvolutionalOperator, int ConvDim,
          typename ElementA_, typename LayoutA_, typename ElementB_,
          typename LayoutB_, typename ElementC_, typename LayoutC_,
          typename ElementCompute_,
          typename ElementAccumulator_ = ElementCompute_,
          typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
          typename InnerProductOp_ = multiply_add<ElementAccumulator_> >
struct ConvReferenceDispatcher;

/// Dispatcher for Conv2d (partially specialized for kConvDim == 2)
template <Provider kProvider, conv::Operator kConvolutionalOperator,
          typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator,
          typename ConvertOp, typename InnerProductOp>
struct ConvReferenceDispatcher<kProvider, kConvolutionalOperator, 2, ElementA,
                               LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                               ElementCompute, ElementAccumulator, ConvertOp,
                               InnerProductOp> {
    /// Runs the Conv2d reference implementation selected by kProvider.
    /// `configuration` must point at a Conv2dConfiguration. Returns
    /// kErrorNotSupported for any provider other than the host or device
    /// reference.
    static Status dispatch(void const* configuration, ElementA* ptr_A,
                           ElementB* ptr_B, ElementC* ptr_C, ElementC* ptr_D,
                           ElementCompute alpha, ElementCompute beta,
                           cudaStream_t stream = nullptr) {
        Conv2dConfiguration const& cfg =
                *static_cast<Conv2dConfiguration const*>(configuration);

        // Operand layouts are queried per conv kind (fprop / dgrad / wgrad),
        // since the roles of A, B, and C differ between the three.
        ConvKind const kind = ConvKindMap<kConvolutionalOperator>::kId;

        if (kProvider == Provider::kReferenceHost) {
            cutlass::reference::host::Conv2d<
                    ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                    ElementCompute, ElementAccumulator, ConvertOp,
                    InnerProductOp>(kConvolutionalOperator, cfg.problem_size,
                                    {ptr_A, cfg.layout_a(kind)},
                                    {ptr_B, cfg.layout_b(kind)},
                                    {ptr_C, cfg.layout_c(kind)},
                                    {ptr_D, cfg.layout_c(kind)}, alpha, beta);
            return Status::kSuccess;
        }

        if (kProvider == Provider::kReferenceDevice) {
            return cutlass::reference::device::Conv2d<
                    ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                    ElementCompute, ElementAccumulator, ConvertOp,
                    InnerProductOp>(kConvolutionalOperator, cfg.problem_size,
                                    {ptr_A, cfg.layout_a(kind)},
                                    {ptr_B, cfg.layout_b(kind)},
                                    {ptr_C, cfg.layout_c(kind)},
                                    {ptr_D, cfg.layout_c(kind)}, alpha, beta,
                                    stream);
        }

        return Status::kErrorNotSupported;
    }
};

/// Dispatcher for Conv3d (partially specialized for kConvDim == 3)
template <Provider kProvider, conv::Operator kConvolutionalOperator,
          typename ElementA, typename LayoutA, typename ElementB,
          typename LayoutB, typename ElementC, typename LayoutC,
          typename ElementCompute, typename ElementAccumulator,
          typename ConvertOp, typename InnerProductOp>
struct ConvReferenceDispatcher<kProvider, kConvolutionalOperator, 3, ElementA,
                               LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                               ElementCompute, ElementAccumulator, ConvertOp,
                               InnerProductOp> {
    /// Runs the Conv3d reference implementation selected by kProvider.
    /// `configuration` must point at a Conv3dConfiguration. Returns
    /// kErrorNotSupported for any provider other than the host or device
    /// reference.
    static Status dispatch(void const* configuration, ElementA* ptr_A,
                           ElementB* ptr_B, ElementC* ptr_C, ElementC* ptr_D,
                           ElementCompute alpha, ElementCompute beta,
                           cudaStream_t stream = nullptr) {
        Conv3dConfiguration const& cfg =
                *static_cast<Conv3dConfiguration const*>(configuration);

        // Operand layouts are queried per conv kind (fprop / dgrad / wgrad),
        // since the roles of A, B, and C differ between the three.
        ConvKind const kind = ConvKindMap<kConvolutionalOperator>::kId;

        if (kProvider == Provider::kReferenceHost) {
            cutlass::reference::host::Conv3d<
                    ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                    ElementCompute, ElementAccumulator, ConvertOp,
                    InnerProductOp>(kConvolutionalOperator, cfg.problem_size,
                                    {ptr_A, cfg.layout_a(kind)},
                                    {ptr_B, cfg.layout_b(kind)},
                                    {ptr_C, cfg.layout_c(kind)},
                                    {ptr_D, cfg.layout_c(kind)}, alpha, beta);
            return Status::kSuccess;
        }

        if (kProvider == Provider::kReferenceDevice) {
            return cutlass::reference::device::Conv3d<
                    ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC,
                    ElementCompute, ElementAccumulator, ConvertOp,
                    InnerProductOp>(kConvolutionalOperator, cfg.problem_size,
                                    {ptr_A, cfg.layout_a(kind)},
                                    {ptr_B, cfg.layout_b(kind)},
                                    {ptr_C, cfg.layout_c(kind)},
                                    {ptr_D, cfg.layout_c(kind)}, alpha, beta,
                                    stream);
        }

        return Status::kErrorNotSupported;
    }
};

}  // namespace detail

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Library Operation wrapping a host- or device-side convolution reference
/// implementation. One instantiation describes and runs a single
/// (provider, conv kind, dimension, type/layout) combination.
template <Provider Provider_, conv::Operator ConvolutionalOperator, int ConvDim,
          typename ElementA_, typename LayoutA_, typename ElementB_,
          typename LayoutB_, typename ElementC_, typename LayoutC_,
          typename ElementCompute_,
          typename ElementAccumulator_ = ElementCompute_,
          typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
          typename InnerProductOp_ = multiply_add<ElementAccumulator_> >
class ConvReferenceOperation : public Operation {
public:
    static Provider const kProvider = Provider_;
    static conv::Operator const kConvolutionalOperator = ConvolutionalOperator;
    static int const kConvDim = ConvDim;

    using ElementA = ElementA_;
    using LayoutA = LayoutA_;
    using ElementB = ElementB_;
    using LayoutB = LayoutB_;
    using ElementC = ElementC_;
    using LayoutC = LayoutC_;
    using ElementCompute = ElementCompute_;
    using ElementAccumulator = ElementAccumulator_;
    using ConvertOp = ConvertOp_;
    using InnerProductOp = InnerProductOp_;

protected:
    /// Storage for the name string; description_.name points into this
    /// buffer, so it must live as long as description_ does.
    std::string name_;

    /// Description published through description()
    ConvDescription description_;

public:
    /// Constructor -- populates description_ and the procedurally generated
    /// operation name.
    ConvReferenceOperation() {
        // Basic information
        description_.provider = kProvider;
        description_.kind = (kConvDim == 2 ? OperationKind::kConv2d
                                           : OperationKind::kConv3d);
        description_.conv_kind = ConvKindMap<kConvolutionalOperator>::kId;
        description_.conv_dim = kConvDim;

        // Tensor description
        description_.A = make_TensorDescription<ElementA, LayoutA>();
        description_.B = make_TensorDescription<ElementB, LayoutB>();
        description_.C = make_TensorDescription<ElementC, LayoutC>();

        // Epilogue compute and accumulator type description.
        // (Previously these two assignments were redundantly repeated at the
        // end of the constructor; the duplicates have been removed.)
        description_.element_epilogue = NumericTypeMap<ElementCompute>::kId;

        description_.tile_description.math_instruction.element_accumulator =
                NumericTypeMap<ElementAccumulator>::kId;

        // Iterator algorithm for convolution reference
        description_.iterator_algorithm = IteratorAlgorithmID::kNone;

        // Compute capability for convolution reference
        description_.tile_description.minimum_compute_capability =
                (kProvider == Provider::kReferenceDevice ? 50 : 0);

        description_.tile_description.maximum_compute_capability = 1024;

        // Procedural name, e.g. "conv2d_fprop_reference_host_f32nhwc_..."
        std::stringstream ss;

        ss << "conv" << kConvDim << "d_" << to_string(description_.conv_kind)
           << "_reference_" << to_string(description_.provider) << "_"
           << to_string(description_.A.element)
           << to_string(description_.A.layout) << "_"
           << to_string(description_.B.element)
           << to_string(description_.B.layout) << "_"
           << to_string(description_.C.element)
           << to_string(description_.C.layout) << "_"
           << to_string(description_.tile_description.math_instruction
                                .element_accumulator);

        name_ = ss.str();

        description_.name = name_.c_str();
    }

    /// Returns the description of the convolution operation
    virtual OperationDescription const& description() const {
        return description_;
    }

    /// Reference operations accept any configuration/arguments of the
    /// matching type.
    virtual Status can_implement(void const* configuration,
                                 void const* arguments) const {
        return Status::kSuccess;
    }

    /// Host workspace holds a copy of the Conv2d/Conv3d configuration.
    virtual uint64_t get_host_workspace_size(void const* configuration) const {
        switch (kConvDim) {
            case 2:
                return sizeof(Conv2dConfiguration);
            case 3:
                return sizeof(Conv3dConfiguration);
            default:
                break;
        }

        return 0;
    }

    /// Reference operations require no device-side workspace.
    virtual uint64_t get_device_workspace_size(
            void const* configuration) const {
        return 0;
    }

    /// Copies the configuration into the host workspace for later use by
    /// run().
    virtual Status initialize(void const* configuration, void* host_workspace,
                              void* device_workspace = nullptr,
                              cudaStream_t stream = nullptr) const {
        std::memcpy(host_workspace, configuration,
                    get_host_workspace_size(configuration));

        return Status::kSuccess;
    }

    /// Runs the reference convolution; host_workspace must contain a
    /// configuration previously stored by initialize().
    virtual Status run(void const* arguments, void* host_workspace,
                       void* device_workspace = nullptr,
                       cudaStream_t stream = nullptr) const {
        ConvArguments const& args =
                *static_cast<ConvArguments const*>(arguments);

        ElementCompute alpha;
        ElementCompute beta;

        // Scalars are read from host pointers.
        alpha = *static_cast<ElementCompute const*>(args.alpha);
        beta = *static_cast<ElementCompute const*>(args.beta);

        // TODO - respect pointer mode

        // Invoke 2D or 3D convolution
        return detail::ConvReferenceDispatcher<
                kProvider, kConvolutionalOperator, kConvDim, ElementA, LayoutA,
                ElementB, LayoutB, ElementC, LayoutC, ElementCompute,
                ElementAccumulator, ConvertOp, InnerProductOp>::
                dispatch(host_workspace,
                         static_cast<ElementA*>(const_cast<void*>(args.A)),
                         static_cast<ElementB*>(const_cast<void*>(args.B)),
                         static_cast<ElementC*>(const_cast<void*>(args.C)),
                         static_cast<ElementC*>(args.D), alpha, beta, stream);
    }
};

///////////////////////////////////////////////////////////////////////////////////////////////////

/// Constructs Fprop reference operators (one host-reference and one
/// device-reference instance) and appends them to the manifest.
template <int kConvDim, typename ElementA_, typename LayoutA_,
          typename ElementB_, typename LayoutB_, typename ElementC_,
          typename LayoutC_, typename ElementCompute_,
          typename ElementAccumulator_ = ElementCompute_,
          typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
          typename InnerProductOp_ = multiply_add<ElementAccumulator_> >
void make_conv_fprop(Manifest& manifest) {
    using HostOp = ConvReferenceOperation<
            Provider::kReferenceHost, conv::Operator::kFprop, kConvDim,
            ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_,
            ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_>;

    using DeviceOp = ConvReferenceOperation<
            Provider::kReferenceDevice, conv::Operator::kFprop, kConvDim,
            ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_,
            ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_>;

    manifest.append(new HostOp);
    manifest.append(new DeviceOp);
}

/// Constructs Dgrad and Wgrad reference operators (host and device variants
/// of each) and appends them to the manifest.
template <int kConvDim, typename ElementA_, typename LayoutA_,
          typename ElementB_, typename LayoutB_, typename ElementC_,
          typename LayoutC_, typename ElementCompute_,
          typename ElementAccumulator_ = ElementCompute_,
          typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
          typename InnerProductOp_ = multiply_add<ElementAccumulator_> >
void make_conv_backwards(Manifest& manifest) {
    using HostDgrad = ConvReferenceOperation<
            Provider::kReferenceHost, conv::Operator::kDgrad, kConvDim,
            ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_,
            ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_>;

    using DeviceDgrad = ConvReferenceOperation<
            Provider::kReferenceDevice, conv::Operator::kDgrad, kConvDim,
            ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_,
            ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_>;

    using HostWgrad = ConvReferenceOperation<
            Provider::kReferenceHost, conv::Operator::kWgrad, kConvDim,
            ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_,
            ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_>;

    using DeviceWgrad = ConvReferenceOperation<
            Provider::kReferenceDevice, conv::Operator::kWgrad, kConvDim,
            ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, LayoutC_,
            ElementCompute_, ElementAccumulator_, ConvertOp_, InnerProductOp_>;

    // Append order matches the original: dgrad (host, device) then
    // wgrad (host, device).
    manifest.append(new HostDgrad);
    manifest.append(new DeviceDgrad);
    manifest.append(new HostWgrad);
    manifest.append(new DeviceWgrad);
}

/// Six operators for the price of one: fprop, dgrad, and wgrad, each with a
/// host-reference and a device-reference variant.
template <int kConvDim, typename ElementA_, typename LayoutA_,
          typename ElementB_, typename LayoutB_, typename ElementC_,
          typename LayoutC_, typename ElementCompute_,
          typename ElementAccumulator_ = ElementCompute_,
          typename ConvertOp_ = NumericConverter<ElementC_, ElementCompute_>,
          typename InnerProductOp_ = multiply_add<ElementAccumulator_> >
void make_conv_all(Manifest& manifest) {
    // Fprop operators first, then the backward (dgrad/wgrad) operators.
    make_conv_fprop<kConvDim, ElementA_, LayoutA_, ElementB_, LayoutB_,
                    ElementC_, LayoutC_, ElementCompute_, ElementAccumulator_,
                    ConvertOp_, InnerProductOp_>(manifest);

    make_conv_backwards<kConvDim, ElementA_, LayoutA_, ElementB_, LayoutB_,
                        ElementC_, LayoutC_, ElementCompute_,
                        ElementAccumulator_, ConvertOp_,
                        InnerProductOp_>(manifest);
}

///////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace library
}  // namespace cutlass

///////////////////////////////////////////////////////////////////////////////////////////////////
